{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "42dc95b0",
   "metadata": {},
   "source": [
    "# 提供高级模型的可解释性\n",
    "\n",
    "1. RF、DT、ET、XGBoost、LightGBM模型的特征重要性。\n",
    "2. XGBoost、LightGBM模型的决策过程。\n",
    "\n",
    "决策过程的可视化依赖GraphViz，需要自行安装（适合动手能力较强的同学）：\n",
    "1. 官网下载：https://graphviz.org/download/\n",
    "2. 安装过程中记得选择将GraphViz添加到当前用户的环境变量（PATH）\n",
    "3. 安装python接口：`pip install graphviz`\n",
    "\n",
    "可以参考博客：https://blog.csdn.net/chenkfkevin/article/details/111443521\n",
    "\n",
    "## LR模型输出公式\n",
    "\n",
    "label = 2.9152857686530913  -0.205248 * prob02 +0.687838 * prob03 +0.432564 * prob04 -0.409445 * prob05 -0.386123 * prob06 -0.915558 * prob07 +0.375201 * prob08 +0.226856 * prob09 +2.639113 * prob10 +1.288628 * pred1\n",
    "\n",
    "## 树模型输出FeatureImportance\n",
    "\n",
    "1. XGBoost\n",
    "2. LightGBM\n",
    "3. ExtraTrees\n",
    "4. RandomForest等\n",
    "\n",
    "## 输出XGBoost和LightGBM决策过程\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e01082a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.Manager import onekey_show\n",
    "\n",
    "# NOTE(review): onekey_show is a project helper called with this notebook's module name;\n",
    "# presumably it displays the description of that module - confirm against onekey_algo.\n",
    "onekey_show('模型可解释性-SHAP')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c4a2f00",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import joblib\n",
    "import pandas as pd\n",
    "from sklearn import tree\n",
    "import matplotlib.pyplot as plt\n",
    "import xgboost as xgb\n",
    "import lightgbm as lgb\n",
    "from onekey_algo import get_param_in_cwd, init_CN\n",
    "from onekey_algo.custom.components.comp1 import plot_feature_importance\n",
    "\n",
    "# NOTE(review): init_CN is a project helper; presumably it prepares plotting for Chinese text - confirm.\n",
    "init_CN()\n",
    "# Logistic-regression coefficients whose absolute value does not exceed this threshold are left out\n",
    "# of the printed formula; set it to None to keep every coefficient.\n",
    "COEF_THRESHOLD = 1e-6\n",
    "\n",
    "# Directory whose files are loaded with joblib and visualised by the next cell.\n",
    "# This is a machine-specific absolute path: change it to your own models directory before running.\n",
    "model_root = r'C:\\Users\\onekey\\Desktop\\onekey_comp\\comp9-Solutions\\sol1. 传统组学-单中心-临床\\models'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbe06140",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import traceback\n",
    "\n",
    "# Export feature-importance plots, tree diagrams and the LR formula for every model file in model_root.\n",
    "save_dir = get_param_in_cwd('save_dir', 'viz')\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# Put the GraphViz directory under ONEKEY_HOME on PATH for the tree plots below.\n",
    "# ONEKEY_HOME may be unset when GraphViz is installed system-wide, so a missing variable is not an error;\n",
    "# the membership test keeps PATH from growing every time this cell is re-run.\n",
    "onekey_home = os.environ.get('ONEKEY_HOME')\n",
    "if onekey_home:\n",
    "    graphviz_dir = os.path.join(onekey_home, r'onekey_envs\\Library\\graphviz')\n",
    "    current_path = os.environ.get('PATH', '')\n",
    "    if graphviz_dir not in current_path.split(os.pathsep):\n",
    "        os.environ['PATH'] = current_path + os.pathsep + graphviz_dir\n",
    "\n",
    "# sorted() makes the processing order (and therefore the cell output) reproducible.\n",
    "for pkl in sorted(os.listdir(model_root)):\n",
    "    model_path = os.path.join(model_root, pkl)\n",
    "    if not os.path.isfile(model_path):\n",
    "        continue  # sub-directories cannot be loaded as models\n",
    "    # Reset per file so the error message below never reports the model of a previous iteration.\n",
    "    model = None\n",
    "    try:\n",
    "        # NOTE: joblib.load unpickles the file, which can execute arbitrary code;\n",
    "        # only point model_root at model files you trust.\n",
    "        model = joblib.load(model_path)\n",
    "        save_name = os.path.splitext(pkl)[0]\n",
    "        pkl_lower = pkl.lower()\n",
    "        print(model.__class__.__name__, \"，模型位置：\", pkl)\n",
    "        # The model family is recognised from the file name, not from the loaded object.\n",
    "        if 'xgboost' in pkl_lower:\n",
    "            plot_feature_importance(model, feature_names=list(model.get_booster().get_score().keys()), show=False)\n",
    "            plt.savefig(os.path.join(save_dir, f\"{save_name}_fimp.svg\"), bbox_inches='tight')\n",
    "            xgb.plot_tree(model, save_name=os.path.join(save_dir, save_name))\n",
    "        elif 'lightgbm' in pkl_lower:\n",
    "            plot_feature_importance(model, feature_names=model.feature_name_, show=False)\n",
    "            plt.savefig(os.path.join(save_dir, f\"{save_name}_fimp.svg\"), bbox_inches='tight')\n",
    "            lgb.plot_tree(model, orientation='vertical', save_name=os.path.join(save_dir, save_name), dpi=300)\n",
    "        elif 'lr' in pkl_lower:\n",
    "            feature_names = list(model.feature_names_in_)\n",
    "            feat_coef = [(feat_name, coef) for feat_name, coef in zip(feature_names, model.coef_[0])\n",
    "                         if COEF_THRESHOLD is None or abs(coef) > COEF_THRESHOLD]\n",
    "            # '+.6f' writes an explicit sign on every term, so the terms can be appended to the\n",
    "            # intercept directly; rstrip() covers the case where no coefficient passes the threshold.\n",
    "            formula = ' '.join([f\"{coef:+.6f} * {feat_name}\" for feat_name, coef in feat_coef])\n",
    "            score = f\"label = {model.intercept_[0]} {formula}\".rstrip()\n",
    "            print(score)\n",
    "        else:\n",
    "            # Save the figure only when plot_feature_importance returns a truthy result.\n",
    "            r = plot_feature_importance(model, feature_names=list(model.feature_names_in_), show=False)\n",
    "            if r:\n",
    "                plt.savefig(os.path.join(save_dir, f\"{save_name}_fimp.svg\"), bbox_inches='tight')\n",
    "        plt.show()\n",
    "    except Exception as e:\n",
    "        # Best effort: report the failure and continue with the next model file.\n",
    "        traceback.print_exc()\n",
    "        # Fall back to the file name when loading itself failed and there is no model object.\n",
    "        failed = model.__class__.__name__ if model is not None else pkl\n",
    "        print(f\"可视化{failed}遇到:{e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a2aba5b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
